import torch
import numpy as np
import pickle
import os
import torchvision
import random
cpath = os.path.dirname(__file__)
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from opacus import PrivacyEngine

import torch
import torchattacks
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.transforms import Resize, InterpolationMode


def MNIST_padding(num_padding, batch_size, root='./data'):
    """Build MNIST train/test loaders whose digits are padded and resized back to 28x28.

    Every image is padded by ``num_padding`` pixels on each side (torchvision's
    default ``Pad`` fill) and then bicubically resized to 28x28. The digit therefore
    shrinks and occupies a smaller fraction of the frame as ``num_padding`` grows.

    Args:
        num_padding: Pixels of padding added to every side before resizing.
        batch_size: Batch size for both loaders.
        root: Directory where MNIST is stored/downloaded (defaults to ``'./data'``).

    Returns:
        ``(trainloader, testloader)``. The train loader shuffles; the test loader does not.
    """
    transform = transforms.Compose([
        transforms.Pad(num_padding),
        Resize(size=(28, 28), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
    ])
    trainset = torchvision.datasets.MNIST(root=root, train=True, download=True, transform=transform)
    testset = torchvision.datasets.MNIST(root=root, train=False, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)
    return trainloader, testloader


class CNN(nn.Module):
    """Two-convolution classifier for single-channel 28x28 inputs, emitting 10 logits.

    Spatial trace for a 28x28 input: conv(5) -> 24, pool(2) -> 12, conv(5) -> 8,
    pool(2) -> 4. The flattened feature vector is therefore 32 * 4 * 4 = 512 wide,
    which is what ``fc1`` expects.
    """

    def __init__(self):
        super().__init__()
        # Registration order matters: it fixes state_dict key order and the
        # sequence in which parameters are initialised from the global RNG.
        self.conv1 = nn.Conv2d(1, 16, 5)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.fc1 = nn.Linear(512, 32)
        self.fc2 = nn.Linear(32, 10)

    def forward(self, x):
        h = F.max_pool2d(F.relu(self.conv1(x)), 2)
        h = F.max_pool2d(F.relu(self.conv2(h)), 2)
        flat = h.view(h.shape[0], -1)  # keep batch dim, flatten C*H*W
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)


class LeNet(nn.Module):
    """LeNet-5-style classifier (tanh activations, average pooling) for 1x28x28 inputs.

    The ``nn.Sequential`` indices below must stay exactly as they are. Checkpoints
    are copied between instances via ``load_state_dict``, whose keys look like
    ``feature.0.weight`` and ``classifier.5.bias``.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: padding=2 makes 28x28 act like 32x32, so conv(5) keeps 28x28;
        # avg-pool halves it to 14x14.
        stage_one = [
            nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=2),
            nn.Tanh(),
            nn.AvgPool2d(2, stride=2),
        ]
        # Stage 2: conv(5) takes 14x14 -> 10x10; avg-pool -> 5x5 with 16 channels.
        stage_two = [
            nn.Conv2d(6, 16, kernel_size=5, stride=1),
            nn.Tanh(),
            nn.AvgPool2d(2, stride=2),
        ]
        self.feature = nn.Sequential(*stage_one, *stage_two)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16 * 5 * 5, 120), nn.Tanh(),
            nn.Linear(120, 84), nn.Tanh(),
            nn.Linear(84, 10),
        )

    def forward(self, x):
        features = self.feature(x)
        return self.classifier(features)


def _train_one_epoch(model, optimizer, train_loader, criterion, device):
    """Run one DP-SGD pass over ``train_loader`` and return the mean per-batch loss."""
    model.train()
    running_loss = 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(train_loader)


def _evaluate_clean(model, test_loader, criterion, device):
    """Return ``(sample-weighted mean loss tensor, correct, total)`` on unperturbed data."""
    model.eval()
    correct = 0
    total = 0
    loss_sum = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            batch = labels.size(0)
            # criterion averages over the batch; re-weight so a short final
            # batch does not count as much as a full one.
            loss_sum += criterion(outputs, labels) * batch
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch
    return loss_sum / total, correct, total


def _evaluate_adversarial(model, attack, test_loader, criterion, device):
    """Return ``(sample-weighted mean loss tensor, accuracy)`` on ``attack``-perturbed data."""
    model.eval()
    correct = 0
    total = 0
    loss_sum = 0
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        # Crafting the attack needs input gradients, so it runs outside no_grad.
        adv_images = attack(images, labels).to(device)
        # Scoring must not build a graph: accumulating grad-tracking losses
        # would keep every batch's autograd graph alive until the epoch ends.
        with torch.no_grad():
            outputs = model(adv_images)
        batch = labels.size(0)
        loss_sum += criterion(outputs, labels) * batch
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += batch
    return loss_sum / total, correct / total


def model_training(model, train_dataloader, test_loader, num_iter, lr, sigma, max_per_sample_grad_norm, delta,
                   *, pgd_eps=4 / 255, pgd_alpha=2 / 255, pgd_steps=40):
    """Train ``model`` with Opacus DP-SGD and report clean and PGD-adversarial test metrics.

    After each pass over the training data, this prints the training loss, the
    clean test loss and accuracy, the privacy budget, and the adversarial loss
    and accuracy under a PGD attack.

    Args:
        model: Network to train; its device determines where batches are sent.
            Its class must be constructible with no arguments, because a fresh
            copy of the weights is built for the attack each epoch.
        train_dataloader: Training loader, handed to ``PrivacyEngine.make_private``.
        test_loader: Evaluation loader used for both clean and adversarial metrics.
        num_iter: Number of passes (epochs) over the training data.
        lr: SGD learning rate (momentum 0).
        sigma: Noise multiplier for DP-SGD.
        max_per_sample_grad_norm: Per-sample gradient clipping bound.
        delta: Target delta passed to ``get_epsilon``.
        pgd_eps, pgd_alpha, pgd_steps: PGD attack radius, step size and number of steps.

    Returns:
        ``(test_loss, test_loss_adv)`` from the final epoch, as 0-dim tensors.
        Both are ``None`` when ``num_iter`` is 0.
    """
    device = next(model.parameters()).device
    original_optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0)
    privacy_engine = PrivacyEngine(secure_mode=False)
    model, optimizer, train_loader = privacy_engine.make_private(
        module=model,
        optimizer=original_optimizer,
        data_loader=train_dataloader,
        noise_multiplier=sigma,
        max_grad_norm=max_per_sample_grad_norm,
    )
    criterion = nn.CrossEntropyLoss()
    test_loss = test_loss_adv = None
    for epoch in range(num_iter):
        train_loss = _train_one_epoch(model, optimizer, train_loader, criterion, device)
        print(f'Iteration {epoch + 1} loss: {train_loss}')

        test_loss, correct, total = _evaluate_clean(model, test_loader, criterion, device)
        epsilon = privacy_engine.get_epsilon(delta)
        print('privacy budget: ', epsilon)
        print('test_loss', test_loss)
        print(f'Accuracy of the network on the {total} test images: {100 * correct // total} %')

        # Attack a plain (unwrapped) copy of the current weights; the privacy
        # wrapper exposes the underlying module as ``_module``.
        attack_model = type(model._module)().to(device)
        attack_model.load_state_dict(model._module.state_dict())
        attack = torchattacks.PGD(attack_model, eps=pgd_eps, alpha=pgd_alpha, steps=pgd_steps)
        test_loss_adv, test_acc_adv = _evaluate_adversarial(model, attack, test_loader, criterion, device)
        print('test_loss_adv:', test_loss_adv)
        print('test_acc_adv:', test_acc_adv)
    return test_loss, test_loss_adv

def main():
    """Train LeNet with DP-SGD on padded MNIST and print the final clean/adversarial test loss."""
    # Padding (pixels per side) -> resulting padding ratio 2p / (2p + 28):
    # 2 -> 12%; 5 -> 26%; 9 -> 39%; 14 -> 50%; 23 -> 62%; 43 -> 75%
    num_padding = 43

    sigma = 0.925
    delta = 1e-5
    max_per_sample_grad_norm = 0.1

    num_iter = 50
    batch_size = 256
    lr = 7

    model = LeNet().cuda()
    train_loader, test_loader = MNIST_padding(num_padding, batch_size)
    # model_training returns a (clean, adversarial) pair, not a single loss.
    test_loss, test_loss_adv = model_training(model, train_loader, test_loader, num_iter, lr, sigma,
                                              max_per_sample_grad_norm, delta)
    print('padding ratio', num_padding * 2 / (num_padding * 2 + 28))
    print('test loss', test_loss)
    print('test loss adv', test_loss_adv)


def show_sample(loader):
    """Plot the first image of one batch drawn from ``loader``, titled with its class name.

    This is a debugging aid and is not called by ``main``. It replaces a commented-out
    snippet that referenced names local to ``MNIST_padding``.

    Assumes single-channel images (as ``MNIST_padding`` produces) and a dataset
    exposing ``classes``. This is believed true for torchvision MNIST — confirm
    before using with other datasets.
    """
    images, labels = next(iter(loader))
    plt.figure(figsize=(4, 4))
    # Drop the channel axis: (1, H, W) -> (H, W), rendered in grayscale.
    plt.imshow(images[0].squeeze(0), cmap='gray')
    plt.title(f"Label: {loader.dataset.classes[labels[0].item()]}")
    plt.axis('off')
    plt.show()


if __name__ == "__main__":
    main()